# Modified from: https://github.com/dlwh/jax_sourceror
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import ast
import enum
import keyword
import warnings
from collections.abc import MutableMapping, MutableSet
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Callable, Union

import jax
import jax.numpy as jnp
import numpy as np
from jax._src.sharding_impls import UNSPECIFIED

if jax.__version__ >= '0.5.0':
    from jax.extend.core import Literal, Var, Jaxpr
else:
    from jax.core import Primitive, Literal, Var, Jaxpr

# Public API of this module: translate a traced function or a jaxpr into
# Python source text. Everything else here is an implementation detail.
__all__ = ['fn_to_python_code', 'jaxpr_to_python_code']


class IdentitySet(MutableSet):
    """Set that compares objects by identity.

    This is a set that compares objects by identity (``id``) instead of
    equality. It is useful for storing objects that are not hashable or that
    should be compared by identity.

    Members are stored in an ``id -> object`` table, so the set holds a strong
    reference to each member and a member's ``id`` cannot be recycled while it
    is still in the set.

    ``collections.abc.Set`` defines ``__eq__`` without ``__hash__``, so
    instances are unhashable and cannot be used as dictionary keys or as
    elements of another set.
    """

    def __init__(self, iterable=None):
        self._data = {}
        if iterable is not None:
            self.update(iterable)

    def __contains__(self, value):
        return id(value) in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self):
        return len(self._data)

    def add(self, value):
        self._data[id(value)] = value

    def discard(self, value):
        self._data.pop(id(value), None)

    def update(self, iterable):
        """Add every element of ``iterable``, like ``set.update``.

        ``MutableSet`` itself does not provide ``update``; the constructor
        relies on this method.
        """
        for value in iterable:
            self.add(value)

    def __repr__(self):
        return f"IdentitySet({list(repr(x) for x in self._data.values())})"

    def __str__(self):
        return f"IdentitySet({list(str(x) for x in self._data.values())})"


class IdentityMap(MutableMapping):
    """Map that compares keys by identity.

    Keys are looked up by ``id(key)`` instead of equality, so unhashable
    objects can be keys and equal-but-distinct objects get separate entries.

    Each entry stores the key alongside its value. This keeps the key alive, so
    its ``id`` cannot be reused by a new object while the entry exists. It also
    lets iteration yield the keys, as the ``Mapping`` protocol requires (and as
    ``keys()``/``items()`` depend on).

    ``collections.abc.Mapping`` defines ``__eq__`` without ``__hash__``, so
    instances are unhashable and cannot be used as dictionary keys or as
    elements of a set.
    """

    def __init__(self, iterable=None):
        # id(key) -> (key, value)
        self._data = {}
        if iterable is not None:
            self.update(iterable)

    def __contains__(self, key):
        return id(key) in self._data

    def __getitem__(self, key):
        try:
            return self._data[id(key)][1]
        except KeyError:
            # Report the missing key itself, not its id.
            raise KeyError(key) from None

    def __setitem__(self, key, value):
        self._data[id(key)] = (key, value)

    def __delitem__(self, key):
        try:
            del self._data[id(key)]
        except KeyError:
            raise KeyError(key) from None

    def __iter__(self):
        return (key for key, _ in self._data.values())

    def __len__(self):
        return len(self._data)

    def __repr__(self):
        return f"IdentityMap({list(f'{k!r}: {v!r}' for k, v in self._data.values())})"

    def __str__(self):
        return f"IdentityMap({list(f'{k!s}: {v!s}' for k, v in self._data.values())})"


@dataclass
class SourcerorState:
    """State for the auto-minimizer. Basically just in charge of naming variables.

    Each ``Var`` receives a short lowercase name the first time it is seen and
    keeps it afterwards; ``skolem`` produces fresh names (containing an
    underscore) for helper functions and temporaries.
    """
    _var_names: IdentityMap[Var, str] = field(default_factory=IdentityMap)
    _skolem_count: int = 0
    # Index of the next candidate variable name. Tracked separately from
    # len(_var_names) because reserved candidates are skipped.
    _name_count: int = 0

    # Candidate names that are never handed out: Python keywords would make
    # the generated source unparsable, and ``jax`` would shadow the module the
    # generated code calls into. (Not annotated, so not a dataclass field.)
    _RESERVED_NAMES = frozenset(keyword.kwlist) | {'jax'}

    def name(self, var, ctx=ast.Load()) -> ast.Name:
        """Return an ``ast.Name`` for ``var``, allocating its name on first use."""
        return ast.Name(id=self.str_name(var), ctx=ctx)

    @staticmethod
    def _encode_index(index: int) -> str:
        # 0..25 -> 'a'..'z'; beyond that base-26 digits with 'a' as zero, so
        # 26 -> 'ba', 27 -> 'bb', 52 -> 'ca' (there is no 'aa'..'az' range).
        letters = ""
        while index >= 26:
            letters += chr(ord('a') + index % 26)
            index //= 26
        letters += chr(ord('a') + index)
        return letters[::-1]

    def str_name(self, var: Var) -> str:
        """Return the generated name for ``var``, allocating a new one if needed."""
        if var in self._var_names:
            return self._var_names[var]

        while True:
            name = self._encode_index(self._name_count)
            self._name_count += 1
            if name not in self._RESERVED_NAMES:
                break

        self._var_names[var] = name
        return name

    def skolem(self, prefix: str):
        """Return a fresh ``<prefix>_<n>`` name; ``n`` increases on every call."""
        self._skolem_count += 1
        return f"{prefix}_{self._skolem_count}"


# Import lines the generated source needs (for example a private JAX symbol
# such as UNSPECIFIED). Handlers add to it during a translation, and the public
# entry points prepend its contents to the emitted source.
prefix_imports = set()


@contextmanager
def catch_imports():
    """Scope ``prefix_imports`` to a single translation.

    The shared set is emptied when the block is entered and emptied again when
    it is left (normally or via an exception), so imports collected for one
    translation never leak into the next.
    """
    prefix_imports.clear()
    try:
        yield
    finally:
        prefix_imports.clear()


def fn_to_python_code(fn, *args, **kwargs):
    """
    Given a function which is defined by jax primitives and the function arguments,
    return the Python code that would be generated by JAX for that function.

    The function is traced with ``jax.make_jaxpr``. Values it closes over (the
    ``consts`` of the resulting closed jaxpr) are inlined as literals, so the
    generated source has no unbound names for them.

    :param fn: The function to generate code for
    :param args: The positional arguments to the function
    :param kwargs: The keyword arguments to the function
    :return: The Python code that would be generated by JAX for that function
    """
    closed_jaxpr = jax.make_jaxpr(fn)(*args, **kwargs)
    # Bind captured constants so constant folding substitutes their values.
    const_env = dict(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))
    jaxpr = partial_eval_jaxpr(closed_jaxpr.jaxpr, const_env)
    state = SourcerorState()

    # functools.partial and similar callables have no __name__; lambdas are
    # named '<lambda>'. Either way, make the name a valid identifier.
    name = getattr(fn, '__name__', 'unknown')
    name = ''.join(ch if ('a' + ch).isidentifier() else '_' for ch in name)
    if not name.isidentifier():
        name = '_' + name

    with catch_imports():
        node = jaxpr_to_py_ast(state, jaxpr, fn_name=name)
        node = _maybe_wrap_fn_for_leaves(node, fn, len(args) + len(kwargs))
        ast.fix_missing_locations(node)
        source = ast.unparse(node)
        if prefix_imports:
            # Sorted so the header does not depend on set iteration order.
            source = "\n".join(sorted(prefix_imports)) + "\n\n" + source
    return source


def jaxpr_to_python_code(jaxpr: Jaxpr,
                         fn_name: str = "generated_function"):
    """
    Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr.

    A closed jaxpr (an object exposing ``.jaxpr`` and ``.consts``, such as the
    result of ``jax.make_jaxpr``) is also accepted; its constants are inlined
    as literals.

    :param jaxpr: The jaxpr (or closed jaxpr) to generate code for.
    :param fn_name: The name of the generated function.
    :return: The Python code that would be generated by JAX for that jaxpr
    """
    if hasattr(jaxpr, 'jaxpr') and hasattr(jaxpr, 'consts'):
        closed = jaxpr
        const_env = dict(zip(closed.jaxpr.constvars, closed.consts))
        jaxpr = partial_eval_jaxpr(closed.jaxpr, const_env)
    else:
        jaxpr = constant_fold_jaxpr(jaxpr)
    state = SourcerorState()
    with catch_imports():
        node = jaxpr_to_py_ast(state, jaxpr, fn_name=fn_name)
        ast.fix_missing_locations(node)
        source = ast.unparse(node)
        if prefix_imports:
            # Sorted so the header does not depend on set iteration order.
            source = "\n".join(sorted(prefix_imports)) + "\n\n" + source
    return source


def register_prim_handler(prim_name, handler):
    """
    Register ``handler`` as the source generator for primitive ``prim_name``.

    A handler is called as ``handler(state, eqn)`` and returns one AST
    statement or a list of them. Registering a name that already has a handler
    replaces it and emits a warning.

    :param prim_name: name of the JAX primitive (e.g. ``'add'``)
    :param handler: callable producing AST statement(s) for an equation
    """
    already_registered = prim_name in prim_to_python
    if already_registered:
        warnings.warn(f"Overwriting handler for primitive {prim_name}")
    prim_to_python[prim_name] = handler


def register_prim_as(prim_name):
    """
    Decorator form of ``register_prim_handler``.

    The decorated function is registered for ``prim_name`` and returned
    unchanged, so it can still be called directly.

    :param prim_name: name of the JAX primitive the decorated handler covers
    :return: the decorator
    """

    def register(handler):
        register_prim_handler(prim_name, handler)
        return handler

    return register


def _assign_stmt(call_expr: Callable):
    """
    Build a primitive handler that emits ``outvars = call_expr(*invars, **params)``.

    ``call_expr`` receives the equation's inputs as AST expressions
    (positionally) and its params as AST expressions (by keyword), and returns
    the right-hand-side expression node.

    :param call_expr: factory for the assigned expression
    :return: a handler ``(state, eqn) -> ast.Assign``
    """

    def handler(state, eqn):
        arg_nodes = [_astify_atom(state, atom) for atom in eqn.invars]
        targets = _astify_outvars(state, eqn.outvars)
        kwarg_nodes = {key: _astify_value(val) for key, val in eqn.params.items()}
        rhs = call_expr(*arg_nodes, **kwarg_nodes)
        return ast.Assign(targets, rhs)

    return handler


def _binop_fn(op: ast.operator):
    """Handler emitting ``out = lhs <op> rhs`` for a binary arithmetic primitive."""
    def make_binop(lhs, rhs):
        return ast.BinOp(left=lhs, op=op, right=rhs)

    return _assign_stmt(make_binop)


def _cmpop_fn(op: ast.cmpop):
    """Handler emitting ``out = lhs <op> rhs`` for a binary comparison primitive."""
    def make_compare(lhs, rhs):
        return ast.Compare(left=lhs, ops=[op], comparators=[rhs])

    return _assign_stmt(make_compare)


def normal_fn(fn_name):
    """
    Create a handler that emits a plain call ``out = fn_name(*invars, **params)``.

    :param fn_name: the (possibly dotted) function name to call in generated code
    :return: a handler ``(state, eqn) -> ast.Assign``
    """

    def build_call(*positional, **named):
        keywords = [ast.keyword(arg=key, value=val) for key, val in named.items()]
        return ast.Call(func=ast.Name(id=fn_name, ctx=ast.Load()),
                        args=list(positional),
                        keywords=keywords)

    return _assign_stmt(build_call)


def _reduce_fn(fn_name: str):
    """
    Create a handler for a reduction primitive rendered as ``fn_name(x, axis=...)``.

    The primitive's ``axes`` param becomes the ``axis`` keyword. Any other
    param has no counterpart in the ``jax.numpy`` reduction signature, so it is
    not rendered (passing it would make the generated call raise TypeError); a
    warning is issued when such a param carries a non-``None`` value.

    :param fn_name: the (possibly dotted) reduction function to call
    :return: a handler ``(state, eqn) -> ast.Assign``
    """
    def reduce_fn_inner(state: SourcerorState, eqn):
        invars = [_astify_atom(state, v) for v in eqn.invars]
        outvars = _astify_outvars(state, eqn.outvars)
        keywords = []
        for k, v in eqn.params.items():
            if k == 'axes':
                keywords.append(ast.keyword(arg='axis', value=_astify_value(tuple(v))))
            elif v is not None:
                warnings.warn(f"Dropping unsupported parameter {k}={v!r} of {eqn.primitive}")
        call_op = ast.Call(
            func=ast.Name(id=fn_name, ctx=ast.Load()),
            args=invars,
            keywords=keywords
        )
        return ast.Assign(outvars, call_op)

    return reduce_fn_inner


prim_to_python = dict()


def _make_div_handler():
    """
    Create the handler for ``div``.

    ``lax.div`` rounds toward zero for integer operands, whereas Python's ``/``
    on JAX integer arrays performs true division and yields floats. So
    ``a / b`` is emitted only when the result dtype is inexact (floating or
    complex); integer division is spelled ``jax.lax.div(a, b)``.
    """
    true_div = _binop_fn(ast.Div())
    lax_div = normal_fn('jax.lax.div')

    def div_handler(state, eqn):
        if jnp.issubdtype(eqn.outvars[0].aval.dtype, jnp.inexact):
            return true_div(state, eqn)
        return lax_div(state, eqn)

    return div_handler


register_prim_handler('add', _binop_fn(ast.Add()))
register_prim_handler('sub', _binop_fn(ast.Sub()))
register_prim_handler('mul', _binop_fn(ast.Mult()))
register_prim_handler('div', _make_div_handler())
register_prim_handler('neg', normal_fn('jax.lax.neg'))
register_prim_handler('lt', _cmpop_fn(ast.Lt()))
register_prim_handler('gt', _cmpop_fn(ast.Gt()))
register_prim_handler('le', _cmpop_fn(ast.LtE()))
register_prim_handler('ge', _cmpop_fn(ast.GtE()))
register_prim_handler('eq', _cmpop_fn(ast.Eq()))
register_prim_handler('ne', _cmpop_fn(ast.NotEq()))
register_prim_handler('min', normal_fn('jax.lax.min'))
register_prim_handler('max', normal_fn('jax.lax.max'))
register_prim_handler('select_n', normal_fn('jax.lax.select_n'))
register_prim_handler('squeeze', normal_fn('jax.lax.squeeze'))
register_prim_handler('broadcast', normal_fn('jax.lax.broadcast'))
register_prim_handler('reduce_sum', _reduce_fn('jax.numpy.sum'))
register_prim_handler('transpose', normal_fn('jax.lax.transpose'))


def _maybe_wrap_fn_for_leaves(node, f, num_args):
    """
    Make the generated function callable with the original (pytree) arguments.

    ``node`` takes one positional parameter per flattened leaf. When that count
    equals ``num_args`` (e.g. every argument was a single array), ``node`` is
    returned unchanged. Otherwise it is nested inside
    ``def <name>(*args, **kwargs)``, which flattens ``(args, kwargs)`` with
    ``jax.tree_util.tree_leaves`` and forwards the leaves to ``node``. The leaf
    order is assumed to match the one used when tracing -- TODO confirm for
    keyword arguments.

    :param node: the generated ``ast.FunctionDef`` over leaves
    :param f: the original callable. Not read: the wrapper reuses
        ``node.name``, since ``f`` may lack ``__name__`` (e.g. a
        ``functools.partial``).
    :param num_args: number of positional plus keyword arguments given to ``f``
    :return: ``node`` itself or the wrapping ``ast.FunctionDef``
    """
    if len(node.args.args) == num_args:
        return node

    def load(name):
        return ast.Name(id=name, ctx=ast.Load())

    # jax.tree_util.tree_leaves is the stable spelling; the top-level
    # jax.tree_leaves alias is deprecated in newer JAX releases.
    tree_leaves = ast.Attribute(
        value=ast.Attribute(value=load('jax'), attr='tree_util', ctx=ast.Load()),
        attr='tree_leaves',
        ctx=ast.Load(),
    )
    leaves = ast.Call(
        func=tree_leaves,
        args=[ast.Tuple(elts=[load('args'), load('kwargs')], ctx=ast.Load())],
        keywords=[],
    )
    forward = ast.Return(
        ast.Call(
            func=load(node.name),
            args=[ast.Starred(value=leaves, ctx=ast.Load())],
            keywords=[],
        )
    )

    return ast.FunctionDef(
        name=node.name,
        args=ast.arguments(
            posonlyargs=[],
            args=[],
            vararg=ast.arg(arg="args", annotation=None),
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=ast.arg(arg="kwargs", annotation=None),
            defaults=[],
        ),
        body=[node, forward],
        decorator_list=[],
    )


def jaxpr_to_py_ast(state: SourcerorState,
                    jaxpr: Jaxpr,
                    fn_name: str = "function"):
    """
    Translate ``jaxpr`` into an ``ast.FunctionDef`` named ``fn_name``.

    Parameters are the jaxpr's invars. Each equation becomes one or more
    statements via its handler in ``prim_to_python``. A primitive without a
    handler is emitted as ``jax.lax.<prim>(...)`` when ``jax.lax`` exposes a
    callable of that name, and as a bare ``<prim>(...)`` call otherwise. Outputs
    that were folded to literals are returned as constants rather than as
    (undefined) variable names.
    """
    # Generate argument declarations
    ast_args = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=state.str_name(var), annotation=None) for var in jaxpr.invars],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )

    stmts = []

    # Generate body of the function
    for eqn in jaxpr.eqns:
        prim = str(eqn.primitive)
        handler = prim_to_python.get(prim)
        if handler is None:
            if callable(getattr(jax.lax, prim, None)):
                handler = normal_fn(f"jax.lax.{prim}")
            else:
                handler = normal_fn(prim)
        eqn_stmts = handler(state, eqn)

        if isinstance(eqn_stmts, list):
            stmts.extend(eqn_stmts)
        else:
            stmts.append(eqn_stmts)

    # Generate return statement; _astify_atom renders Literal outvars as values.
    if len(jaxpr.outvars) == 1:
        returns = _astify_atom(state, jaxpr.outvars[0])
    else:
        returns = ast.Tuple(elts=[_astify_atom(state, var) for var in jaxpr.outvars],
                            ctx=ast.Load())
    stmts.append(ast.Return(value=returns))

    return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[])


def constant_fold_jaxpr(jaxpr: Jaxpr):
    """
    Return ``jaxpr`` simplified by constant folding.

    This is ``partial_eval_jaxpr`` with no pre-bound variables, so only
    equations computable from literals (and from earlier folded results) are
    evaluated; see ``partial_eval_jaxpr`` for the exact rules.
    """
    no_bindings = {}
    return partial_eval_jaxpr(jaxpr, no_bindings)


def partial_eval_jaxpr(jaxpr, env):
    """
    Partially evaluate ``jaxpr`` given bindings for some of its variables.

    :param jaxpr: the jaxpr to simplify.
    :param env: maps variables of ``jaxpr`` (typically invars or constvars) to
        either a concrete value (array, scalar, or ``Literal``) or a ``Var`` of
        an enclosing scope that should be referenced in its place.
    :return: a new jaxpr in which

        * equations whose inputs are all concrete are evaluated and removed,
          unless their primitive is in ``constant_fold_blacklist``.
          ``closed_call`` is also evaluated (inlined) when some inputs are bound
          to enclosing ``Var``s, because ``_eval_eqn`` handles it by recursive
          partial evaluation;
        * remaining uses of bound or folded variables are replaced by literals
          or by the enclosing ``Var``;
        * invars bound in ``env`` are removed. Every other invar is kept, even
          if unused or only returned as an output, so the generated function's
          signature still matches its call sites.
    """
    # Copy the bindings, storing concrete values unwrapped (callers may pass
    # Literal atoms).
    env = {var: (val.val if isinstance(val, Literal) else val) for var, val in env.items()}
    new_eqns = []

    def read(var):
        if isinstance(var, Literal):
            return var.val
        else:
            return env.get(var, None)

    def read_or_self(var):
        out = read(var)
        if out is None:
            return var
        elif isinstance(out, Var):
            return out
        elif isinstance(out, Literal):
            return Literal(out.val, var.aval)
        else:
            assert not isinstance(out, Jaxpr)
            return Literal(out, var.aval)

    for eqn in jaxpr.eqns:
        vals = [read(var) for var in eqn.invars]
        known = all(val is not None for val in vals)
        # Inputs bound to enclosing-scope Vars are symbolic; they cannot be
        # passed to primitive.bind.
        concrete = known and not any(isinstance(val, Var) for val in vals)
        if eqn.primitive.name in constant_fold_blacklist:
            new_eqns.append(eqn)
        elif concrete or (known and eqn.primitive.name == "closed_call"):
            # go ahead and eval it
            out = _eval_eqn(eqn, vals)

            # either a jaxpr to inline (partial eval), or a value / list of values
            if isinstance(out, Jaxpr):
                new_eqns.extend(out.eqns)
                out = out.outvars
            elif not isinstance(out, (tuple, list)):
                out = (out,)

            for var, val in zip(eqn.outvars, out):
                assert not isinstance(val, Jaxpr)
                env[var] = val.val if isinstance(val, Literal) else val
        else:
            new_eqns.append(eqn)

    # now that we've evaled everything, inline all the constants
    out_eqns = [eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars))
                for eqn in new_eqns]

    # Bound invars were substituted away; everything else stays a parameter.
    invars = tuple(var for var in jaxpr.invars if read(var) is None)

    # sub in any constants for outvars
    outvars = tuple(read_or_self(var) for var in jaxpr.outvars)

    return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars)


def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jnp.ndarray]:
    """
    Evaluate one equation on already-known input values.

    ``closed_call`` is handled by partially evaluating its body with the inputs
    bound, which yields a ``Jaxpr`` for the caller to inline. Every other
    primitive (``scan`` included) is bound directly to ``vals``.
    """
    primitive = eqn.primitive
    if primitive.name != "closed_call":
        return primitive.bind(*vals, **eqn.params)

    assert (primitive.call_primitive, primitive.map_primitive) == (True, False)
    body = eqn.params['call_jaxpr'].jaxpr
    bindings = dict(zip(body.invars, vals))
    return partial_eval_jaxpr(body, bindings)


@register_prim_as('dot_general')
def _astify_dot_general(state, eqn):
    """
    Emit source for a ``dot_general`` equation.

    A plain matrix product is written as ``jax.numpy.matmul``. That means
    contracting dim 1 of a 2-D lhs with dim 0 of a 1-D/2-D rhs, with no batch
    dims and default precision. ``jnp.matmul`` contracts the *last* lhs axis
    and treats leading axes as batch axes, so this spelling is only valid for
    those ranks. Anything else falls back to ``jax.lax.dot_general``.

    When ``preferred_element_type`` differs from the input dtypes it is passed
    to ``matmul``, so accumulation happens in that type. A cast applied after
    the product would not give that.
    """
    x, y = eqn.invars
    d = eqn.params['dimension_numbers']
    precision = eqn.params['precision']
    preferred_element_type = eqn.params['preferred_element_type']

    # True when no explicit output dtype needs to be requested.
    has_dtype = preferred_element_type is None or x.aval.dtype == y.aval.dtype == preferred_element_type

    # recognize simple matmul case
    is_plain_matmul = (d == (((1,), (0,)), ((), ()))
                       and precision is None
                       and x.aval.ndim == 2
                       and y.aval.ndim in (1, 2))
    if is_plain_matmul:
        invars = [_astify_atom(state, x), _astify_atom(state, y)]
        outvars = _astify_outvars(state, eqn.outvars)
        keywords = []
        if not has_dtype:
            keywords.append(ast.keyword(arg='preferred_element_type',
                                        value=_astify_value(preferred_element_type)))
        return ast.Assign(targets=outvars, value=ast.Call(
            func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='matmul', ctx=ast.Load()),
            args=invars,
            keywords=keywords))

    # TODO: convert to einsum?

    invars = [_astify_atom(state, x),
              _astify_atom(state, y),
              _astify_value(d),
              _astify_value(precision),
              _astify_value(preferred_element_type)]
    outvars = _astify_outvars(state, eqn.outvars)
    return ast.Assign(
        targets=outvars,
        value=ast.Call(
            func=ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()), attr='dot_general', ctx=ast.Load()),
            args=invars,
            keywords=[]
        )
    )


@register_prim_as('dynamic_slice')
def _sourcify_dynamic_slice(state, eqn):
    """
    Emit ``out = jax.lax.dynamic_slice(operand, (i0, i1, ...), **params)``.

    The jaxpr passes each start index as a separate operand, while the Python
    API takes them as one sequence, so they are packed into a tuple.
    """
    operand, *start_atoms = eqn.invars
    start_tuple = ast.Tuple(elts=[_astify_atom(state, atom) for atom in start_atoms],
                            ctx=ast.Load())
    targets = _astify_outvars(state, eqn.outvars)
    keywords = [ast.keyword(arg=key, value=_astify_value(val))
                for key, val in eqn.params.items()]
    callee = ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()),
                           attr='dynamic_slice',
                           ctx=ast.Load())
    call = ast.Call(func=callee,
                    args=[_astify_atom(state, operand), start_tuple],
                    keywords=keywords)
    return ast.Assign(targets=targets, value=call)


@register_prim_as('slice')
def _sourcify_slice(state, eqn):
    """
    Emit a static ``lax.slice`` as NumPy-style indexing, e.g. ``out = x[0:2, 1:5:2]``.

    Each dimension becomes a real ``ast.Slice`` holding the start, the limit,
    and the stride when it is not 1. For the in-bounds, non-negative indices
    ``lax.slice`` takes, Python slicing selects the same elements.
    """
    sliced = eqn.invars[0]
    outvars = _astify_outvars(state, eqn.outvars)
    start_indices = eqn.params['start_indices']
    limit_indices = eqn.params['limit_indices']
    strides = eqn.params['strides']
    if strides is None:
        strides = (None,) * len(start_indices)

    def as_slice(start, limit, stride):
        step = None if stride in (None, 1) else ast.Constant(value=int(stride))
        return ast.Slice(lower=ast.Constant(value=int(start)),
                         upper=ast.Constant(value=int(limit)),
                         step=step)

    indices = [as_slice(s, e, stride)
               for s, e, stride in zip(start_indices, limit_indices, strides)]
    # A lone slice reads better than a 1-tuple; 0-d operands give x[()].
    index = indices[0] if len(indices) == 1 else ast.Tuple(elts=indices, ctx=ast.Load())
    return ast.Assign(
        targets=outvars,
        value=ast.Subscript(
            value=_astify_atom(state, sliced),
            slice=index,
            ctx=ast.Load()
        )
    )


@register_prim_as('dynamic_update_slice')
def _sourcify_dynamic_update_slice(state, eqn):
    """
    Emit ``out = jax.lax.dynamic_update_slice(operand, update, (i0, i1, ...))``.

    The jaxpr operands are ``[operand, update, *start_indices]``. The start
    indices are always packed into a tuple, even for a 1-D operand:
    ``lax.dynamic_update_slice`` expects a sequence of indices, not a bare
    scalar.
    """
    target = _astify_atom(state, eqn.invars[0])
    update = _astify_atom(state, eqn.invars[1])
    start_indices = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[2:]],
                              ctx=ast.Load())
    outvars = _astify_outvars(state, eqn.outvars)

    return ast.Assign(targets=outvars, value=ast.Call(
        func=ast.Attribute(
            value=ast.Name(id='jax.lax', ctx=ast.Load()),
            attr='dynamic_update_slice',
            ctx=ast.Load()
        ),
        args=[target, update, start_indices],
        keywords=[]
    ))


@register_prim_as('convert_element_type')
def _astify_convert_element_type(state, eqn):
    """Emit ``out = x.astype(new_dtype)``; the primitive's other params are not rendered."""
    targets = _astify_outvars(state, eqn.outvars)
    assert len(eqn.invars) == 1
    (operand,) = eqn.invars
    operand_node = _astify_atom(state, operand)
    dtype_node = _astify_value(eqn.params['new_dtype'])
    astype = ast.Attribute(value=operand_node, attr='astype', ctx=ast.Load())
    call = ast.Call(func=astype, args=[dtype_node], keywords=[])
    return ast.Assign(targets=targets, value=call)


def is_array(arr):
    """Return True for NumPy arrays, NumPy scalars (``np.generic``) and JAX arrays."""
    array_types = (np.ndarray, np.generic, jnp.ndarray)
    return isinstance(arr, array_types)


def _astify_array(value):
    """
    Convert a NumPy/JAX array or NumPy scalar into an AST expression.

    * ``np.int64`` scalars and 0-d values of dtype float32, int32, bool or
      int64 become plain Python constants (the dtype is not spelled out).
    * Every other value, including 0-d values of other dtypes, becomes
      ``jax.numpy.array(<nested lists or scalar>, dtype=<dtype>)``. This form
      works for every dtype. Calling the dtype directly would not: dtypes
      without a named alias render as ``jax.numpy.dtype('...')``, and a dtype
      object is not callable.
    """
    assert is_array(value)
    if isinstance(value, np.int64):
        return ast.Constant(value=int(value))

    if value.ndim == 0 and value.dtype in (jnp.float32, jnp.int32, jnp.bool_, jnp.int64):
        return ast.Constant(value=value.item())

    # For 0-d values tolist() returns a bare Python scalar.
    values = value.tolist()

    def rec_astify_list(values):
        if isinstance(values, list):
            return ast.List(elts=[rec_astify_list(val) for val in values], ctx=ast.Load())
        else:
            return ast.Constant(value=values)

    return ast.Call(
        func=ast.Attribute(
            value=ast.Name(id='jax.numpy', ctx=ast.Load()),
            attr='array',
            ctx=ast.Load()
        ),
        args=[rec_astify_list(values)],
        keywords=[ast.keyword(arg='dtype',
                              value=_astify_value(value.dtype))]
    )


def _astify_atom(state: SourcerorState, var: Union[Literal, Var]):
    """
    Render a jaxpr atom as an expression.

    A ``Literal`` becomes its constant value; a ``Var`` becomes a name
    reference allocated by ``state``.

    :raises NotImplementedError: for any other kind of atom.
    """
    if isinstance(var, Literal):
        return _astify_value(var.val)
    if isinstance(var, Var):
        return state.name(var)
    raise NotImplementedError()


def _astify_value(value):
    """
    Convert a parameter value into an AST expression.

    Handles:

    * arrays and NumPy scalars (via ``_astify_array``);
    * Python scalars, strings and ``None``;
    * tuples and lists (both rendered as tuples);
    * dtypes;
    * ``slice`` objects (as ``slice(start, stop, step)``);
    * JAX's ``UNSPECIFIED`` sentinel, recording the import it needs in
      ``prefix_imports``;
    * enum members. Enums whose class is exported by ``jax.lax`` (e.g.
      ``Precision``) are qualified as ``jax.lax.<Class>``, so the emitted name
      resolves.

    Anything else is rendered by parsing its ``repr`` as an *expression*, with
    a warning.
    """
    assert not isinstance(value, (Literal, Var))

    if is_array(value):
        return _astify_array(value)
    elif isinstance(value, (int, bool, float, str, type(None))):
        return ast.Constant(value=value)
    elif isinstance(value, (tuple, list)):
        return ast.Tuple(elts=[_astify_value(v) for v in value], ctx=ast.Load())
    elif isinstance(value, jnp.dtype):
        if value.name in ('float32', 'float64', 'int32', 'int64', 'bfloat16', 'float16'):
            return ast.Attribute(
                value=ast.Name(id='jax.numpy', ctx=ast.Load()),
                attr=value.name,
                ctx=ast.Load()
            )
        elif value.name == 'bool':
            return ast.Attribute(
                value=ast.Name(id='jax.numpy', ctx=ast.Load()),
                attr='bool_',
                ctx=ast.Load()
            )
        else:
            return ast.Call(
                func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()),
                                   attr='dtype',
                                   ctx=ast.Load()),
                args=[ast.Constant(value=str(value))],
                keywords=[]
            )
    elif isinstance(value, slice):
        return ast.Call(
            func=ast.Name(id='slice', ctx=ast.Load()),
            args=[_astify_value(value.start),
                  _astify_value(value.stop),
                  _astify_value(value.step)],
            keywords=[]
        )
    elif value is UNSPECIFIED:
        prefix_imports.add('from jax._src.sharding_impls import UNSPECIFIED')
        return ast.Name(id='UNSPECIFIED', ctx=ast.Load())
    elif isinstance(value, enum.Enum):
        enum_cls = type(value)
        if getattr(jax.lax, enum_cls.__name__, None) is enum_cls:
            owner = ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()),
                                  attr=enum_cls.__name__,
                                  ctx=ast.Load())
        else:
            owner = ast.Name(id=enum_cls.__qualname__, ctx=ast.Load())
        return ast.Attribute(value=owner, attr=value.name, ctx=ast.Load())
    else:
        warnings.warn(f"Unknown value type {type(value)}")
        # mode='eval' yields an expression node; a plain parse would give an
        # ast.Expr *statement*, which is wrong inside other expressions.
        return ast.parse(repr(value), mode='eval').body


def _astify_outvars(state, outvars):
    """
    Build the assignment target list for an equation's outputs.

    Always returns a one-element list: a single ``Name`` when there is exactly
    one output, otherwise a ``Tuple`` of names (``Store`` context) so that
    multiple (or zero) outputs are unpacked.
    """
    names = [state.name(v, ctx=ast.Store()) for v in outvars]
    target = names[0] if len(names) == 1 else ast.Tuple(elts=names, ctx=ast.Store())
    return [target]


def maybe_tuple_vars(vars):
    """Return the only expression in ``vars`` as-is; wrap zero or several in a load-context ``ast.Tuple``."""
    if len(vars) != 1:
        return ast.Tuple(elts=vars, ctx=ast.Load())
    (only,) = vars
    return only


def maybe_untuple_vars(var, is_tuple):
    """Wrap ``var`` in ``ast.Starred`` (``*var``) when ``is_tuple`` is truthy; otherwise return it unchanged."""
    return ast.Starred(value=var, ctx=ast.Load()) if is_tuple else var


@register_prim_as('scan')
def _astify_scan(state, eqn):
    """
    Emit a ``scan`` equation as a helper function plus ``jax.lax.scan``.

    The scan operands are ``[consts, carry, xs]``. Consts have no counterpart
    in the Python API. They are bound while partially evaluating the body, so
    the body refers directly to the enclosing variables (or inlined literals)
    as a closure.

    Simple case: one carry, one xs, and a body returning exactly
    ``(carry, y)``. The body function is passed to ``jax.lax.scan`` unchanged,
    together with the carry and xs operands (never the consts).

    General case: the body is wrapped in ``def fn(carry, x)``. The wrapper
    unpacks the carry/xs tuples into the body's parameters and returns
    ``(carry, ys)`` as ``jax.lax.scan`` expects; the results are then unpacked
    into the equation's outputs.

    The wrapper parameters and the temporaries holding the scan results use
    fresh ``skolem`` names, which contain an underscore. Generated variable
    names are bare lowercase letters, so the skolem names can neither shadow an
    enclosing variable the body closes over nor clobber one.
    """
    assert eqn.primitive.name == 'scan'

    # the args to scan are [constants, carry, xs]
    num_consts = eqn.params['num_consts']
    num_carry = eqn.params['num_carry']

    # TODO: bring back map
    # if num_carry == 0:
    # this is a map
    # return _astify_map(eqn)

    constant_args = eqn.invars[:num_consts]
    carries = eqn.invars[num_consts:num_consts + num_carry]
    xs = eqn.invars[num_consts + num_carry:]

    closed_body = eqn.params['jaxpr']
    body = closed_body.jaxpr
    # Bind the body's own closed-over constants and the scan consts; Literal
    # operands are unwrapped to their values.
    env = dict(zip(body.constvars, closed_body.consts))
    for inner, outer in zip(body.invars[:num_consts], constant_args):
        env[inner] = outer.val if isinstance(outer, Literal) else outer
    jaxpr = partial_eval_jaxpr(body, env)

    fn_name = state.skolem('fn')
    fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name)

    length = _astify_value(eqn.params['length'])
    unroll = _astify_value(eqn.params['unroll'])
    reverse = _astify_value(eqn.params['reverse'])
    scan_keywords = [ast.keyword(arg='length', value=length),
                     ast.keyword(arg='unroll', value=unroll),
                     ast.keyword(arg='reverse', value=reverse)]

    stmts = []

    # The body can be handed to jax.lax.scan directly only if it already has
    # the (carry, x) -> (carry, y) shape.
    is_simple = (num_carry == 1 and len(xs) == 1
                 and len(jaxpr.invars) == 2 and len(jaxpr.outvars) == 2)

    if is_simple:
        stmts.append(fn_ast)
        scan_call = ast.Assign(
            targets=_astify_outvars(state, eqn.outvars),
            value=ast.Call(
                func=ast.Name(id='jax.lax.scan', ctx=ast.Load()),
                args=[ast.Name(id=fn_name, ctx=ast.Load())]
                     + [_astify_atom(state, v) for v in [*carries, *xs]],
                keywords=scan_keywords
            )
        )
        stmts.append(scan_call)
        return stmts

    # what we want is something like:
    # def fn(carry, x):
    #     (params...) = (*carry, *x)
    #     ...
    #     return (carries...), (ys...)
    # final_carry, ys = jax.lax.scan(fn, (carries...), (xs...))
    carry_name = state.skolem('carry')
    x_name = state.skolem('x')

    modified_signature = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=carry_name), ast.arg(arg=x_name)],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )

    initial_assign = ast.Assign(
        targets=[ast.Tuple(elts=[ast.Name(id=a.arg, ctx=ast.Store()) for a in fn_ast.args.args],
                           ctx=ast.Store())],
        value=ast.Tuple(
            elts=[maybe_untuple_vars(ast.Name(id=carry_name, ctx=ast.Load()), num_carry != 1),
                  maybe_untuple_vars(ast.Name(id=x_name, ctx=ast.Load()), len(xs) != 1)],
            ctx=ast.Load()
        )
    )

    fn_return = fn_ast.body[-1]
    assert isinstance(fn_return, ast.Return)

    fn_return_value = fn_return.value

    if isinstance(fn_return_value, ast.Tuple):
        ret_carries = maybe_tuple_vars(fn_return_value.elts[:num_carry])
        ret_ys = maybe_tuple_vars(fn_return_value.elts[num_carry:])
    elif num_carry == 0:
        ret_carries = _astify_value(())
        ret_ys = fn_return_value
    else:
        ret_carries = fn_return_value
        ret_ys = _astify_value(())

    scan_return = ast.Return(
        value=ast.Tuple(elts=[ret_carries, ret_ys], ctx=ast.Load())
    )

    new_body = [initial_assign] + list(fn_ast.body[:-1]) + [scan_return]

    stmts.append(ast.FunctionDef(
        name=fn_name,
        args=modified_signature,
        body=new_body,
        decorator_list=[]
    ))

    final_carry_name = state.skolem('final_carry')
    ys_name = state.skolem('ys')

    scan_call = ast.Assign(
        targets=[
            ast.Tuple(
                elts=[ast.Name(id=final_carry_name, ctx=ast.Store()),
                      ast.Name(id=ys_name, ctx=ast.Store())],
                ctx=ast.Store()
            )
        ],
        value=ast.Call(
            func=ast.Name(id='jax.lax.scan', ctx=ast.Load()),
            args=[ast.Name(id=fn_name, ctx=ast.Load()),
                  maybe_tuple_vars([_astify_atom(state, v) for v in carries]),
                  maybe_tuple_vars([_astify_atom(state, v) for v in xs])],
            keywords=scan_keywords
        )
    )
    stmts.append(scan_call)

    if num_carry > 0:
        stmts.append(ast.Assign(
            targets=_astify_outvars(state, eqn.outvars[:num_carry]),
            value=ast.Name(id=final_carry_name, ctx=ast.Load())
        ))

    if num_carry < len(eqn.outvars):
        stmts.append(ast.Assign(
            targets=_astify_outvars(state, eqn.outvars[num_carry:]),
            value=ast.Name(id=ys_name, ctx=ast.Load())
        ))

    return stmts


def _astify_map(state, eqn):
    """
    Emit a carry-less ``scan`` as ``jax.lax.map``.

    Not currently registered (see the TODO in ``_astify_scan``). The body
    becomes a helper taking one parameter per mapped operand. ``jax.lax.map``
    passes a single tuple argument, so a lambda spreads it into the helper.
    Scan consts are bound during partial evaluation, so the helper closes over
    them instead of having them mapped over.
    """
    assert eqn.primitive.name == 'scan'
    assert eqn.params['num_carry'] == 0

    num_consts = eqn.params['num_consts']
    closed_body = eqn.params['jaxpr']
    body = closed_body.jaxpr
    env = dict(zip(body.constvars, closed_body.consts))
    for inner, outer in zip(body.invars[:num_consts], eqn.invars[:num_consts]):
        env[inner] = outer.val if isinstance(outer, Literal) else outer
    jaxpr = partial_eval_jaxpr(body, env)

    fn_name = state.skolem('fn')
    fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name)

    # Skolem names contain an underscore, so the lambda parameter cannot shadow
    # a generated variable.
    args_name = state.skolem('args')
    lam = ast.parse(f"lambda {args_name}: {fn_name}(*{args_name})", mode='eval').body

    assign = ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=ast.Call(
            func=ast.Name(id='jax.lax.map', ctx=ast.Load()),
            args=[lam,
                  ast.Tuple(elts=[_astify_atom(state, v) for v in eqn.invars[num_consts:]],
                            ctx=ast.Load())],
            keywords=[]
        )
    )

    return [fn_ast, assign]


@register_prim_as('closed_call')
def _astify_closed_call(state, eqn):
    """
    Emit a ``closed_call`` as a nested helper function followed by a call to it.

    Literal operands and the callee's own closed-over constants are bound
    during partial evaluation, so they are folded into the helper and not
    passed at the call site. The call passes exactly the outer operands that
    correspond to the helper's remaining parameters, in the helper's
    parameter order.
    """
    closed = eqn.params['call_jaxpr']
    raw_jaxpr = closed.jaxpr
    env = dict(zip(raw_jaxpr.constvars, closed.consts))
    env.update((inner, outer.val)
               for inner, outer in zip(raw_jaxpr.invars, eqn.invars)
               if isinstance(outer, Literal))
    call_jaxpr = partial_eval_jaxpr(raw_jaxpr, env)
    fn_name = state.skolem('fn')

    fn_ast = jaxpr_to_py_ast(state, call_jaxpr, fn_name)

    # Map each inner parameter (by identity) back to its call-site operand.
    outer_by_inner = {id(inner): outer for inner, outer in zip(raw_jaxpr.invars, eqn.invars)}
    invars = [_astify_atom(state, outer_by_inner[id(inner)]) for inner in call_jaxpr.invars]
    outvars = _astify_outvars(state, eqn.outvars)

    assign = ast.Assign(
        targets=outvars,
        value=ast.Call(
            func=ast.Name(id=fn_name, ctx=ast.Load()),
            args=invars,
            keywords=[]
        )
    )

    return [fn_ast, assign]


@register_prim_as('pjit')
def _astify_pjit(state, eqn):
    """
    Emit a ``pjit`` equation as a helper function called through ``jax.jit``.

    Rendered params:

    * ``in_shardings`` / ``out_shardings`` are omitted when every entry is
      ``UNSPECIFIED`` (``jax.jit``'s default) and passed through otherwise;
    * ``donated_invars`` is converted to ``donate_argnums``, the keyword
      ``jax.jit`` accepts, indexed by the helper's parameters.

    ``resource_env``, ``keep_unused`` and ``inline`` are not rendered. The
    equation ``name`` (which may be e.g. ``'<lambda>'``) is turned into a valid
    identifier before being used as the helper's name prefix. The callee's
    closed-over constants are inlined.
    """
    closed = eqn.params['jaxpr']
    donated_invars = eqn.params['donated_invars']
    in_shardings = eqn.params['in_shardings']
    out_shardings = eqn.params['out_shardings']
    name = eqn.params['name']

    safe_name = ''.join(ch if ('a' + ch).isidentifier() else '_' for ch in name)
    if not safe_name.isidentifier():
        safe_name = '_' + safe_name

    # preprocess the function
    raw_jaxpr = closed.jaxpr
    jaxpr = partial_eval_jaxpr(raw_jaxpr, dict(zip(raw_jaxpr.constvars, closed.consts)))
    fn_name = state.skolem(safe_name)
    fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name)

    # Operand positions whose inner parameter the helper still takes.
    surviving = {id(var) for var in jaxpr.invars}
    kept = [i for i, var in enumerate(raw_jaxpr.invars) if id(var) in surviving]

    def select(per_input):
        if isinstance(per_input, (tuple, list)):
            return tuple(per_input[i] for i in kept)
        return per_input

    def is_unspecified(shardings):
        items = shardings if isinstance(shardings, (tuple, list)) else (shardings,)
        return all(s is UNSPECIFIED for s in items)

    keywords = []
    in_shardings = select(in_shardings)
    if not is_unspecified(in_shardings):
        keywords.append(ast.keyword(arg='in_shardings', value=_astify_value(in_shardings)))
    if not is_unspecified(out_shardings):
        keywords.append(ast.keyword(arg='out_shardings', value=_astify_value(out_shardings)))

    donate_argnums = tuple(pos for pos, i in enumerate(kept) if donated_invars[i])
    if donate_argnums:
        keywords.append(ast.keyword(arg='donate_argnums', value=_astify_value(donate_argnums)))

    jitted_fn = ast.Call(
        func=ast.Attribute(
            value=ast.Name(id='jax', ctx=ast.Load()),
            attr='jit',
            ctx=ast.Load()
        ),
        args=[ast.Name(id=fn_name, ctx=ast.Load())],
        keywords=keywords
    )

    assign = ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=ast.Call(
            func=jitted_fn,
            args=[_astify_atom(state, eqn.invars[i]) for i in kept],
            keywords=[]
        )
    )

    return [fn_ast, assign]


@register_prim_as('remat2')
def _astify_remat(state: SourcerorState, eqn):
    """Translate a ``remat2`` equation into a ``jax.checkpoint``-wrapped call.

    Emits three statements:

    1. a definition of the constant-folded inner jaxpr as a fresh function
       ``fn_name``;
    2. ``ckpt_<fn_name> = jax.checkpoint(<fn_name>[, prevent_cse=False])``;
    3. ``<outvars> = ckpt_<fn_name>(<invars>)``.

    NOTE(review): only ``prevent_cse`` is forwarded; other params of the
    equation (e.g. any rematerialization policy) are not reproduced.
    """
    call_jaxpr = constant_fold_jaxpr(eqn.params['jaxpr'])
    fn_name = state.skolem('fn')

    fn_ast = jaxpr_to_py_ast(state, call_jaxpr, fn_name)

    invars = [_astify_atom(state, v) for v in eqn.invars]
    outvars = _astify_outvars(state, eqn.outvars)

    ckpt_name = f"ckpt_{fn_name}"

    # jax.checkpoint defaults to prevent_cse=True, so the keyword is only
    # spelled out when the traced call disabled it.
    checkpoint_keywords = []
    if not eqn.params.get('prevent_cse', True):
        checkpoint_keywords.append(
            ast.keyword(arg='prevent_cse', value=ast.Constant(value=False)))

    lam = ast.Assign(
        targets=[ast.Name(id=ckpt_name, ctx=ast.Store())],
        value=ast.Call(
            func=ast.Attribute(
                value=ast.Name(id='jax', ctx=ast.Load()),
                attr='checkpoint',
                ctx=ast.Load(),
            ),
            args=[ast.Name(id=fn_name, ctx=ast.Load())],
            keywords=checkpoint_keywords,
        ),
    )

    assign = ast.Assign(
        targets=outvars,
        value=ast.Call(
            func=ast.Name(id=ckpt_name, ctx=ast.Load()),
            args=invars,
            keywords=[],
        ),
    )

    return [fn_ast, lam, assign]


@register_prim_as('reshape')
def _astify_reshape(state, eqn):
    """Translate ``lax.reshape`` into ``jax.numpy`` calls.

    ``lax.reshape`` can fold a transpose into the reshape through its
    ``dimensions`` param. The emitted expression is therefore
    ``jax.numpy.reshape(jax.numpy.transpose(x, dimensions), new_sizes)``,
    with the transpose omitted when ``dimensions`` is None.

    Returns a single-element list holding the assignment to the outvars.
    """
    perm, target_shape = eqn.params['dimensions'], eqn.params['new_sizes']

    operand_expr = _astify_atom(state, eqn.invars[0])
    if perm is not None:
        # Apply the folded-in transpose before reshaping.
        operand_expr = ast.Call(
            func=ast.Name(id='jax.numpy.transpose', ctx=ast.Load()),
            args=[operand_expr, _astify_value(perm)],
            keywords=[],
        )

    targets = _astify_outvars(state, eqn.outvars)
    reshape_call = ast.Call(
        func=ast.Name(id='jax.numpy.reshape', ctx=ast.Load()),
        args=[operand_expr, _astify_value(target_shape)],
        keywords=[],
    )
    return [ast.Assign(targets=targets, value=reshape_call)]


@register_prim_as('add_any')
def _astify_add_any(state, eqn):
    """Emit ``add_any`` as a plain ``+`` via the shared binary-operator helper.

    NOTE(review): ``add_any`` is not part of the documented lax API. Treating
    it as ordinary addition is a best guess; confirm if exact semantics matter.
    """
    emit_addition = _binop_fn(ast.Add())
    return emit_addition(state, eqn)


@register_prim_as('broadcast_in_dim')
def _astify_broadcast_in_dim(state, eqn):
    """Translate ``broadcast_in_dim``, preferring ``zeros``/``ones``/``full``.

    ``broadcast_in_dim`` is how zeros, ones, full, etc. are implemented. When
    the operand is a scalar literal (a 0-d NumPy array or NumPy scalar) and
    ``broadcast_dimensions`` is empty, this emits the matching ``jax.numpy``
    constructor:

    * ``jax.numpy.zeros(shape, dtype)`` for a literal equal to 0 (except -0.0),
    * ``jax.numpy.ones(shape, dtype)`` for a literal equal to 1,
    * ``jax.numpy.full(shape, value, dtype)`` otherwise.

    Every other case falls back to a direct ``jax.lax.broadcast_in_dim`` call.
    """
    assert len(eqn.invars) == 1
    value = eqn.invars[0]
    shape = eqn.params['shape']
    broadcast_dimensions = eqn.params['broadcast_dimensions']

    if not isinstance(value, Literal) or len(broadcast_dimensions) != 0:
        return normal_fn('jax.lax.broadcast_in_dim')(state, eqn)

    literal = value.val
    # NumPy scalars (np.generic) are 0-d just like 0-d arrays and expose the
    # same ``ndim``/``item()``/``dtype`` interface, so accept both.
    if not isinstance(literal, (np.ndarray, np.generic)) or literal.ndim != 0:
        return normal_fn('jax.lax.broadcast_in_dim')(state, eqn)

    constant_value = literal.item()
    # -0.0 == 0, but jax.numpy.zeros would produce +0.0 and lose the sign, so
    # negative zero is routed through ``full`` instead.
    is_negative_zero = isinstance(constant_value, float) and bool(np.signbit(constant_value))

    if constant_value == 0 and not is_negative_zero:
        ctor_name = 'zeros'
        args = [_astify_value(shape)]
    elif constant_value == 1:
        ctor_name = 'ones'
        args = [_astify_value(shape)]
    else:
        ctor_name = 'full'
        args = [_astify_value(shape), _astify_value(constant_value)]
    args.append(_astify_value(literal.dtype))

    call = ast.Call(
        ast.Attribute(
            value=ast.Name(id='jax.numpy', ctx=ast.Load()),
            attr=ctor_name,
            ctx=ast.Load()
        ),
        args=args,
        keywords=[]
    )

    return [ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=call
    )]


@register_prim_as('random_wrap')
def _astify_random_wrap(state, eqn):
    """Emit ``random_wrap`` as a plain alias of its first input.

    The primitive is treated as a no-op: the outvars are bound directly to
    ``eqn.invars[0]``, and the equation's params are not read.

    Returns a one-element statement list, like every other handler here
    (previously a bare ``ast.Assign`` was returned, unlike its siblings).
    """
    return [ast.Assign(
        targets=_astify_outvars(state, eqn.outvars),
        value=_astify_atom(state, eqn.invars[0]),
    )]


# Primitive names that are excluded from constant folding.
# NOTE(review): presumably consulted by ``constant_fold_jaxpr`` so that
# broadcasts of constants stay as (cheap) broadcast ops instead of being
# folded into materialized arrays. This lets handlers such as the
# broadcast_in_dim one see them; confirm against the folding pass.
constant_fold_blacklist = {'broadcast', 'broadcast_in_dim'}
